import torch
from torch import nn
from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2PreTrainedModel, DebertaV2Model
from transformers.models.bert.modeling_bert import BertPooler
from models.camib import CaMIB
from torch.nn import MSELoss
import random
import configs
from configs import DEVICE


class DeBertaForSequenceClassification(DebertaV2PreTrainedModel):
    """DeBERTa-v2 text encoder fused with visual and acoustic streams via CaMIB.

    The text encoder output and the raw visual/acoustic features are each
    projected to ``multimodal_config.hidden_dim`` and handed to the ``CaMIB``
    fusion module.  The fusion module returns per-modality representations
    (``t``, ``v``, ``a``), a "trivial" representation (``c``), a "causal"
    representation (``o``) and two auxiliary losses.  Each representation has
    its own expand -> LayerNorm -> classifier head.

    NOTE(review): the LayerNorm/classifier layers are sized with
    ``config.hidden_size`` while the ``expand_*`` layers emit
    ``configs.TEXT_DIM`` features, so the two are assumed to be equal --
    confirm against the project configuration.
    """

    def __init__(self, config, multimodal_config):
        """Build the encoder, the projection/expansion layers and the heads.

        Args:
            config: Hugging Face model config; ``hidden_size`` and
                ``num_labels`` are read from it.
            multimodal_config: Project config object; the attributes
                ``model``, ``hidden_dim``, ``output_dim``, ``dim``,
                ``p_beta`` and ``dropout_prob`` are read from it.
        """
        super().__init__(config)
        TEXT_DIM, VISUAL_DIM, ACOUSTIC_DIM = (
            configs.TEXT_DIM,
            configs.VISUAL_DIM,
            configs.ACOUSTIC_DIM
        )
        self.config = config
        self.num_labels = config.num_labels
        # A single pooler instance is shared by the c / o / co heads.
        self.pooler = BertPooler(config)
        model = DebertaV2Model.from_pretrained(multimodal_config.model)
        self.model = model.to(DEVICE)

        hidden_dim = multimodal_config.hidden_dim
        output_dim = multimodal_config.output_dim

        # Per-modality projections into the fusion module's input width.
        self.project_t = nn.Linear(TEXT_DIM, hidden_dim)
        self.project_v = nn.Linear(VISUAL_DIM, hidden_dim)
        self.project_a = nn.Linear(ACOUSTIC_DIM, hidden_dim)

        # Expansions from the fusion output width back to TEXT_DIM.
        self.expand_t = nn.Linear(output_dim, TEXT_DIM)
        self.expand_v = nn.Linear(output_dim, TEXT_DIM)
        self.expand_a = nn.Linear(output_dim, TEXT_DIM)
        self.expand_c = nn.Linear(output_dim, TEXT_DIM)
        self.expand_o = nn.Linear(output_dim, TEXT_DIM)

        self.LayerNorm_t = nn.LayerNorm(config.hidden_size)
        self.LayerNorm_v = nn.LayerNorm(config.hidden_size)
        self.LayerNorm_a = nn.LayerNorm(config.hidden_size)
        self.LayerNorm_c = nn.LayerNorm(config.hidden_size)
        self.LayerNorm_o = nn.LayerNorm(config.hidden_size)
        self.LayerNorm_co = nn.LayerNorm(config.hidden_size)

        self.classifier_t = nn.Linear(config.hidden_size, config.num_labels)
        self.classifier_v = nn.Linear(config.hidden_size, config.num_labels)
        self.classifier_a = nn.Linear(config.hidden_size, config.num_labels)
        self.classifier_c = nn.Linear(config.hidden_size, config.num_labels)
        self.classifier_o = nn.Linear(config.hidden_size, config.num_labels)
        self.classifier_co = nn.Linear(config.hidden_size, config.num_labels)

        self.Fusion = CaMIB(
            input_dim=hidden_dim,
            dim=multimodal_config.dim,
            output_dim=output_dim,
            beta=multimodal_config.p_beta
        )
        self.mse_loss = nn.MSELoss()
        self.dropout = nn.Dropout(multimodal_config.dropout_prob)
        # NOTE(review): init_weights() runs after the pretrained encoder has
        # been loaded into self.model -- confirm that the installed
        # transformers version does not re-initialise those weights here.
        self.init_weights()

    def random_op(self, xc, xo):
        """Return ``xo`` plus a copy of ``xc`` shuffled along dimension 0.

        The permutation is drawn with Python's ``random`` module, so it is
        controlled by ``random.seed`` rather than by the torch RNG.

        Args:
            xc: Tensor whose first dimension is permuted.
            xo: Tensor added to the permuted ``xc``.

        Returns:
            ``xc[perm] + xo``.
        """
        order = list(range(xc.shape[0]))
        random.shuffle(order)
        # Explicit dtype/device: a valid index tensor even for an empty batch,
        # and no implicit host/device mismatch when indexing.
        random_idx = torch.tensor(order, dtype=torch.long, device=xc.device)
        return xc[random_idx] + xo

    def _unimodal_mse(self, fused, expand, layer_norm, classifier, label_ids):
        """MSE of one modality head: expand -> LayerNorm -> mean(dim=1) -> classify."""
        pooled = layer_norm(expand(fused)).mean(dim=1)
        logits = classifier(pooled)
        return self.mse_loss(logits.view(-1), label_ids.view(-1))

    def _pooled_logits(self, hidden, layer_norm, classifier):
        """Logits of one fused head: LayerNorm -> shared pooler -> dropout -> classify."""
        pooled = self.dropout(self.pooler(layer_norm(hidden)))
        return classifier(pooled)

    def forward(
            self,
            input_ids: torch.Tensor,  # (batch_size, max_seq_length)
            visual: torch.Tensor,
            acoustic: torch.Tensor,
            label_ids: torch.Tensor
    ) -> tuple:
        """Run the encoder, the fusion module and every prediction head.

        Args:
            input_ids: Token ids, ``(batch_size, max_seq_length)``.
            visual: Visual features; last dimension must be
                ``configs.VISUAL_DIM`` (input of ``project_v``).
            acoustic: Acoustic features; last dimension must be
                ``configs.ACOUSTIC_DIM`` (input of ``project_a``).
            label_ids: Regression targets; flattened and compared with the
                flattened unimodal logits.

        Returns:
            Tuple ``(logits_o, logits_co, loss_c, unimodal_mse_loss,
            align_loss, KL)`` where ``logits_o`` comes from the causal
            representation, ``logits_co`` from the causal representation
            mixed with batch-shuffled trivial features, ``loss_c`` is the
            regulariser on the trivial logits, ``unimodal_mse_loss`` is the
            sum of the three per-modality MSE losses, and ``align_loss`` /
            ``KL`` are passed through unchanged from the fusion module.
        """
        embedding_output = self.model(input_ids)
        x = embedding_output[0]  # (batch_size, max_seq_length, hidden_dim)
        text = self.project_t(x)
        visual = self.project_v(visual)
        acoustic = self.project_a(acoustic)

        (output_t, output_v, output_a,
         output_c, output_o, align_loss, KL) = self.Fusion(text, visual, acoustic)

        # Per-modality auxiliary regression losses (text + visual + acoustic).
        mse_loss_t = self._unimodal_mse(
            output_t, self.expand_t, self.LayerNorm_t, self.classifier_t, label_ids)
        mse_loss_v = self._unimodal_mse(
            output_v, self.expand_v, self.LayerNorm_v, self.classifier_v, label_ids)
        mse_loss_a = self._unimodal_mse(
            output_a, self.expand_a, self.LayerNorm_a, self.classifier_a, label_ids)
        unimodal_mse_loss = mse_loss_t + mse_loss_v + mse_loss_a

        # c: trivial    o: causal
        # Heads are evaluated in the order c, o, co so dropout draws its
        # masks in the same sequence on every call.
        h_c = self.expand_c(output_c)
        logits_c = self._pooled_logits(h_c, self.LayerNorm_c, self.classifier_c)

        h_o = self.expand_o(output_o)
        logits_o = self._pooled_logits(h_o, self.LayerNorm_o, self.classifier_o)

        # Pair every causal sample with the trivial part of a random sample.
        h_co = self.random_op(h_c, h_o)
        logits_co = self._pooled_logits(h_co, self.LayerNorm_co, self.classifier_co)

        # Closed-form KL divergence between two univariate Gaussians,
        # N(mu_trivial, sigma_trivial^2) and N(mu_target, sigma_target^2),
        # with the trivial logits as the mean of the first one.
        mu_trivial = logits_c
        sigma_trivial = torch.ones_like(logits_c) * 3.0
        mu_target = torch.zeros_like(logits_c)
        sigma_target = torch.sqrt(torch.tensor(3.0))
        kl_loss = torch.log(sigma_target / sigma_trivial) + \
            (sigma_trivial**2 + (mu_trivial - mu_target)**2) / (2 * sigma_target**2) - 0.5
        loss_c = kl_loss.mean()

        return logits_o, logits_co, loss_c, unimodal_mse_loss, align_loss, KL



